import pandas as pd
import wandb
from pandas.io.formats.style import Styler


def run_table(run_set):
    """Build grid-world result tables from a set of evaluation runs.

    Every run is expected to expose ``config`` (a mapping) and
    ``summary._json_dict`` (a mapping), such as the objects returned by
    ``wandb.Api().runs(...)``.  Runs whose config lacks a ``"skip_training"``
    key are ignored.

    Runs are grouped by (map name, training shield, evaluation shield) and each
    group is rendered as one LaTeX cell::

        "<mean reward> $\\pm$ <sem of reward> (<mean unsafe actions * 100>)"

    where the parenthesised part is ``(0)`` whenever that mean is not > 0.
    A group containing a single run has an undefined standard error, which
    pandas reports as NaN (rendered as ``nan``).  Shield identifiers are
    translated to display names through ``name_map`` (for example
    ``"slugs_centralized"`` -> ``"Centralized"``).

    Args:
        run_set: Iterable of run objects.

    Returns:
        ``(df_full, df_diag)``:

        * ``df_full`` -- indexed by ("Map Name", "Shield"), with one column per
          evaluation shield; combinations that were never run are NaN.
        * ``df_diag`` -- indexed by "Map Name", with one column per shield
          (column axis named "Shield"), keeping only the groups that were
          evaluated with the shield they were trained with.

    Raises:
        ValueError: if no run in ``run_set`` has a ``"skip_training"`` key.
    """
    name_map = {
        "centralized": "Centralized",
        "slugs_centralized": "Centralized",
        "decentralized": "Decentralized",
        "slugs_decentralized": "Decentralized",
        "none": "No Shield",
    }
    group_keys = ["map_name", "shield", "eval_shield"]

    records = []
    for r in run_set:
        config = r.config
        # Only runs carrying the "skip_training" key belong in these tables.
        if "skip_training" not in config:
            continue
        summary = r.summary._json_dict
        records.append({
            "map_name": config["grid_world_map_name"],
            "shield": config["shield"],
            "eval_shield": config["evaluation_shield"],
            "safety_violations": summary["test/avg_unsafe_actions"],
            "rew": summary["test/avg_rew_sum_0"],
        })

    if not records:
        raise ValueError("run_set contains no runs with a 'skip_training' config key")

    df = pd.DataFrame(records, columns=group_keys + ["safety_violations", "rew"])
    df["shield"] = df["shield"].map(name_map)
    df["eval_shield"] = df["eval_shield"].map(name_map)

    def format_cell(row):
        """Render one aggregated group (a row of ``stats``) as a LaTeX cell."""
        rew_mean = row[("rew", "mean")]
        rew_sem = row[("rew", "sem")]
        violations = row[("safety_violations", "mean")]
        violation_text = f"({(violations * 100):1.01f})" if violations > 0 else "(0)"
        return f"{rew_mean:2.1f} $\\pm$ {rew_sem:1.1f} " + violation_text

    # groupby sorts by its keys by default, so no explicit sort_values() is needed.
    stats = df.groupby(group_keys).agg({"safety_violations": ["mean"], "rew": ["mean", "sem"]})
    formatted = stats.apply(format_cell, axis=1)

    # Full grid: one row per (map, training shield), one column per evaluation shield.
    df_full = formatted.unstack(level="eval_shield")
    df_full.index.names = ["Map Name", "Shield"]
    df_full.columns.name = None

    # Diagonal: keep only groups evaluated with the same shield they were trained with.
    same_shield = (formatted.index.get_level_values("shield")
                   == formatted.index.get_level_values("eval_shield"))
    df_diag = formatted[same_shield].droplevel("eval_shield").unstack(level="shield")
    df_diag.index.name = "Map Name"
    df_diag.columns.name = "Shield"

    return df_full, df_diag


def run_table_particle(run_set):
    """Build particle-environment result tables from a set of evaluation runs.

    Every run is expected to expose ``config`` (a mapping) and
    ``summary._json_dict`` (a mapping), such as the objects returned by
    ``wandb.Api().runs(...)``.  Runs whose config lacks a ``"skip_training"``
    key are ignored.

    Runs are grouped by (observability, start type, training shield,
    evaluation shield) and each group is rendered as one LaTeX cell::

        "<mean reward> $\\pm$ <sem of reward> (<mean unsafe actions * 100>)"

    where the parenthesised part is ``(0)`` whenever that mean is not > 0.
    Labels are translated for display:

    * ``particle_agents_observe_momentum``: "True" -> "Full", "False" -> "Partial"
    * ``randomize_starts``: "True" -> "Random", "False" -> "Fixed"
    * shield identifiers through ``name_map`` (e.g. "slugs_decentralized" ->
      "Decentralized")

    The True/False labels match either the strings "True"/"False" or real
    booleans. Any other value is kept as-is.

    Args:
        run_set: Iterable of run objects.

    Returns:
        ``(df_full, df_diag)``:

        * ``df_full`` -- indexed by ("Observability", "Start Type", "Shield"),
          with one column per evaluation shield; missing combinations are NaN.
        * ``df_diag`` -- indexed by ("Observability", "Start Type"), with one
          column per shield (column axis named "Shield"), keeping only the
          groups evaluated with the shield they were trained with.

    Raises:
        ValueError: if no run in ``run_set`` has a ``"skip_training"`` key.
    """
    name_map = {
        "centralized": "Centralized",
        "slugs_centralized": "Centralized",
        "decentralized": "Decentralized",
        "slugs_decentralized": "Decentralized",
        "none": "No Shield",
    }
    observability_names = {"False": "Partial", "True": "Full"}
    start_type_names = {"False": "Fixed", "True": "Random"}
    group_keys = ["observe_momentum", "random_start", "shield", "eval_shield"]

    records = []
    for r in run_set:
        config = r.config
        # Only runs carrying the "skip_training" key belong in these tables.
        if "skip_training" not in config:
            continue
        summary = r.summary._json_dict
        records.append({
            "observe_momentum": config["particle_agents_observe_momentum"],
            "random_start": config["randomize_starts"],
            "shield": config["shield"],
            "eval_shield": config["evaluation_shield"],
            "safety_violations": summary["test/avg_unsafe_actions"],
            "rew": summary["test/avg_rew_sum_0"],
        })

    if not records:
        raise ValueError("run_set contains no runs with a 'skip_training' config key")

    df = pd.DataFrame(records, columns=group_keys + ["safety_violations", "rew"])
    df["shield"] = df["shield"].map(name_map)
    df["eval_shield"] = df["eval_shield"].map(name_map)

    def format_cell(row):
        """Render one aggregated group (a row of ``stats``) as a LaTeX cell."""
        rew_mean = row[("rew", "mean")]
        rew_sem = row[("rew", "sem")]
        violations = row[("safety_violations", "mean")]
        violation_text = f"({(violations * 100):1.01f})" if violations > 0 else "(0)"
        return f"{rew_mean:2.1f} $\\pm$ {rew_sem:1.1f} " + violation_text

    def relabel(index, level, labels):
        """Return ``index`` with the values of level ``level`` translated via ``labels``."""
        # str() lets both "True"/"False" strings and real booleans match;
        # values without a translation are kept unchanged.
        new_level = index.levels[level].map(lambda value: labels.get(str(value), value))
        # set_levels returns a new index (its ``inplace`` flag and positional
        # ``level`` were removed in pandas 2.0) and keeps the existing level
        # names, so display names are assigned separately further down.
        return index.set_levels(new_level, level=level)

    # groupby sorts by its keys by default, so no explicit sort_values() is needed.
    stats = df.groupby(group_keys).agg({"safety_violations": ["mean"], "rew": ["mean", "sem"]})
    formatted = stats.apply(format_cell, axis=1)

    # Translate level values only; the codes (and hence row order) are unchanged.
    index = relabel(formatted.index, 0, observability_names)
    formatted.index = relabel(index, 1, start_type_names)

    # Full grid: one row per (observability, start type, training shield),
    # one column per evaluation shield.
    df_full = formatted.unstack(level="eval_shield")
    df_full.index.names = ["Observability", "Start Type", "Shield"]
    df_full.columns.name = None

    # Diagonal: keep only groups evaluated with the same shield they were trained with.
    same_shield = (formatted.index.get_level_values("shield")
                   == formatted.index.get_level_values("eval_shield"))
    df_diag = formatted[same_shield].droplevel("shield").unstack(level="eval_shield")
    df_diag.index.names = ["Observability", "Start Type"]
    df_diag.columns.name = "Shield"

    return df_full, df_diag


def format_to_latex(table, name, **override):
    """Write ``table`` as a LaTeX tabular to ``<name>.tex`` using ``Styler.to_latex``.

    Default options: ``multirow_align="c"``, ``hrules=True``,
    ``clines="skip-last;data"``, and a ``column_format`` with one left-aligned
    ("l") column per index level plus one per data column, so the tabular
    spec always matches the table's width.  Any keyword passed in
    ``override`` replaces the corresponding default and is forwarded to
    ``Styler.to_latex`` unchanged.

    Args:
        table: DataFrame to render.
        name: Output path without the ``.tex`` extension.
        **override: Extra or overriding keyword arguments for ``Styler.to_latex``.
    """
    latex_format = {
        "multirow_align": "c",
        # Derived from the table shape rather than hard-coded, so tables that
        # are not exactly five columns wide still get a matching column spec.
        "column_format": "l" * (table.index.nlevels + len(table.columns)),
        "hrules": True,
        "clines": "skip-last;data",
    }

    latex_format.update(override)

    Styler(table).to_latex(name + ".tex", **latex_format)


if __name__ == '__main__':
    api = wandb.Api()

    # Particle environment, momentum sweep over all runs (currently disabled):
    # particle_full, particle_diag = run_table_particle(api.runs("dmelcer9/Centralized-Verification-Particle-Momentum"))
    # format_to_latex(particle_full, "particle_full", column_format="llllll")
    # format_to_latex(particle_diag, "particle_diag")

    # Particle environment: finished DNN runs from one parallel-config file.
    dnn_runs = api.runs("dmelcer9/Centralized-Verification-Particle-Momentum",
                        filters={"tags": "parallel_configs/IndivQLearningParticleDNNOnlyOneHotEval.csv",
                                 "state": "finished"})
    particle_dnn_full, particle_dnn_diag = run_table_particle(dnn_runs)
    format_to_latex(particle_dnn_full, "particle_dnn_full", column_format="llllll")
    format_to_latex(particle_dnn_diag, "particle_dnn_diag")

    # Grid-world Slugs sweep: one full table per start type (fixed first, then
    # random), followed by a combined diagonal table keyed by start type.
    gridworld_diags = []
    for randomize, output_name in (("False", "slugs_sweep_no_random_start"),
                                   ("True", "slugs_sweep_random_start")):
        sweep_full, sweep_diag = run_table(api.runs("dmelcer9/Centralized-Verification-Slugs-Sweep",
                                                    filters={"config.randomize_starts": randomize}))
        format_to_latex(sweep_full, output_name)
        gridworld_diags.append(sweep_diag)

    format_to_latex(pd.concat(gridworld_diags, keys=("Fixed", "Random"), names=["Start Type"]),
                    "gridworld_diag")

    # Archived invocations (all projects under "dmelcer9/"), kept for reference
    # only.  They were written against an older API: they treat run_table's
    # result as a single table (it now returns a (full, diag) tuple) and use a
    # module-level `latex_format` dict that no longer exists.  Port them to
    # format_to_latex before re-enabling:
    #   Centralized-Verification-Slugs-1k-Episodes   randomize_starts=False -> thousand_eps.tex
    #   Centralized-Verification-Shield-Sweep-1      randomize_starts=False -> thousand_eps.tex
    #       (NOTE: same output file as the entry above)
    #   Centralized-Verification-Slugs-Terminate-On-Collision
    #       randomize_starts=False -> slugs_sweep_no_random_start_term.tex
    #       randomize_starts=True  -> slugs_sweep_random_start_term.tex
    #   Centralized-Verification-Sweep-2             randomize_starts=False -> million_steps.tex
    #   Centralized-Verification-Rand-Start-25       (no filter)            -> rand_start_2million_steps.tex
    #   Centralized-Verification-Slugs-Decentralized-Consistent-Ordering
    #       randomize_starts=False -> slugs_sweep_no_random_start_co.tex
    #       randomize_starts=True  -> slugs_sweep_random_start_co.tex
